{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# VGG\n",
    "计算机视觉是一直深度学习的主战场，从这里我们将接触到近几年非常流行的卷积网络结构，网络结构由浅变深，参数越来越多，网络有着更多的跨层链接，首先我们先介绍一个数据集 cifar10，我们将以此数据集为例介绍各种卷积网络的结构。\n",
    "\n",
    "## CIFAR 10\n",
    "cifar 10 这个数据集一共有 50000 张训练集，10000 张测试集，两个数据集里面的图片都是 png 彩色图片，图片大小是 32 x 32 x 3，一共是 10 分类问题，分别为飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船和卡车。这个数据集是对网络性能测试一个非常重要的指标，可以说如果一个网络在这个数据集上超过另外一个网络，那么这个网络性能上一定要比另外一个网络好，目前这个数据集最好的结果是 95% 左右的测试集准确率。\n",
    "\n",
    "![](https://ws1.sinaimg.cn/large/006tNc79ly1fmpjxxq7wcj30db0ae7ag.jpg)\n",
    "\n",
    "你能用肉眼对这些图片进行分类吗？\n",
    "\n",
    "cifar 10 已经被 pytorch 内置了，使用非常方便，只需要调用 `torchvision.datasets.CIFAR10` 就可以了"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## VGGNet\n",
    "vggNet 是第一个真正意义上的深层网络结构，其是 ImageNet2014年的冠军，得益于 python 的函数和循环，我们能够非常方便地构建重复结构的深层网络。\n",
    "\n",
    "vgg 的网络结构非常简单，就是不断地堆叠卷积层和池化层，下面是一个简单的图示\n",
    "\n",
    "![](https://ws4.sinaimg.cn/large/006tNc79ly1fmpk5smtidj307n0dx3yv.jpg)\n",
    "\n",
    "vgg 几乎全部使用 3 x 3 的卷积核以及 2 x 2 的池化层，使用小的卷积核进行多层的堆叠和一个大的卷积核的感受野是相同的，同时小的卷积核还能减少参数，同时可以有更深的结构。\n",
    "\n",
    "vgg 的一个关键就是使用很多层 3 x 3 的卷积然后再使用一个最大池化层，这个模块被使用了很多次，下面我们照着这个结构来写一写"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-22T09:01:51.296457Z",
     "start_time": "2017-12-22T09:01:50.883050Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import nn\n",
    "from torch.autograd import Variable\n",
    "from torchvision.datasets import CIFAR10"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们可以定义一个 vgg 的 block，传入三个参数，第一个是模型层数，第二个是输入的通道数，第三个是输出的通道数，第一层卷积接受的输入通道就是图片输入的通道数，然后输出最后的输出通道数，后面的卷积接受的通道数就是最后的输出通道数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-22T09:01:51.312500Z",
     "start_time": "2017-12-22T09:01:51.298777Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def vgg_block(num_convs, in_channels, out_channels):\n",
    "    net = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), nn.ReLU(True)] # 定义第一层\n",
    "    \n",
    "    for i in range(num_convs-1): # 定义后面的很多层\n",
    "        net.append(nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1))\n",
    "        net.append(nn.ReLU(True))\n",
    "        \n",
    "    net.append(nn.MaxPool2d(2, 2)) # 定义池化层\n",
    "    return nn.Sequential(*net)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们可以将模型打印出来看看结构"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-22T08:20:40.819497Z",
     "start_time": "2017-12-22T08:20:40.808853Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequential(\n",
      "  (0): Conv2d (64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "  (1): ReLU(inplace)\n",
      "  (2): Conv2d (128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "  (3): ReLU(inplace)\n",
      "  (4): Conv2d (128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "  (5): ReLU(inplace)\n",
      "  (6): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "block_demo = vgg_block(3, 64, 128)\n",
    "print(block_demo)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-22T07:52:04.632406Z",
     "start_time": "2017-12-22T07:52:02.381987Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 128, 150, 150])\n"
     ]
    }
   ],
   "source": [
    "# 首先定义输入为 (1, 64, 300, 300)\n",
    "input_demo = Variable(torch.zeros(1, 64, 300, 300))\n",
    "output_demo = block_demo(input_demo)\n",
    "print(output_demo.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以看到输出就变为了 (1, 128, 150, 150)，可以看到经过了这一个 vgg block，输入大小被减半，通道数变成了 128\n",
    "\n",
    "下面我们定义一个函数对这个 vgg block 进行堆叠"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-22T09:01:54.497712Z",
     "start_time": "2017-12-22T09:01:54.489255Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def vgg_stack(num_convs, channels):\n",
    "    net = []\n",
    "    for n, c in zip(num_convs, channels):\n",
    "        in_c = c[0]\n",
    "        out_c = c[1]\n",
    "        net.append(vgg_block(n, in_c, out_c))\n",
    "    return nn.Sequential(*net)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "作为实例，我们定义一个稍微简单一点的 vgg 结构，其中有 8 个卷积层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-22T09:01:55.149378Z",
     "start_time": "2017-12-22T09:01:55.041923Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Sequential(\n",
      "  (0): Sequential(\n",
      "    (0): Conv2d (3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (1): ReLU(inplace)\n",
      "    (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))\n",
      "  )\n",
      "  (1): Sequential(\n",
      "    (0): Conv2d (64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (1): ReLU(inplace)\n",
      "    (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))\n",
      "  )\n",
      "  (2): Sequential(\n",
      "    (0): Conv2d (128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (1): ReLU(inplace)\n",
      "    (2): Conv2d (256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (3): ReLU(inplace)\n",
      "    (4): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))\n",
      "  )\n",
      "  (3): Sequential(\n",
      "    (0): Conv2d (256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (1): ReLU(inplace)\n",
      "    (2): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (3): ReLU(inplace)\n",
      "    (4): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))\n",
      "  )\n",
      "  (4): Sequential(\n",
      "    (0): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (1): ReLU(inplace)\n",
      "    (2): Conv2d (512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n",
      "    (3): ReLU(inplace)\n",
      "    (4): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), dilation=(1, 1))\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "vgg_net = vgg_stack((1, 1, 2, 2, 2), ((3, 64), (64, 128), (128, 256), (256, 512), (512, 512)))\n",
    "print(vgg_net)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们可以看到网络结构中有个 5 个 最大池化，说明图片的大小会减少 5 倍，我们可以验证一下，输入一张 256 x 256 的图片看看结果是什么"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-22T08:52:44.049650Z",
     "start_time": "2017-12-22T08:52:43.431478Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 512, 8, 8])\n"
     ]
    }
   ],
   "source": [
    "test_x = Variable(torch.zeros(1, 3, 256, 256))\n",
    "test_y = vgg_net(test_x)\n",
    "print(test_y.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以看到图片减小了 $2^5$ 倍，最后再加上几层全连接，就能够得到我们想要的分类输出"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-22T09:01:57.323034Z",
     "start_time": "2017-12-22T09:01:57.306864Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "class vgg(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(vgg, self).__init__()\n",
    "        self.feature = vgg_net\n",
    "        self.fc = nn.Sequential(\n",
    "            nn.Linear(512, 100),\n",
    "            nn.ReLU(True),\n",
    "            nn.Linear(100, 10)\n",
    "        )\n",
    "    def forward(self, x):\n",
    "        x = self.feature(x)\n",
    "        x = x.view(x.shape[0], -1)\n",
    "        x = self.fc(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "然后我们可以训练我们的模型看看在 cifar10 上的效果"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-22T09:01:59.921373Z",
     "start_time": "2017-12-22T09:01:58.709531Z"
    },
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from utils import train\n",
    "\n",
    "def data_tf(x):\n",
    "    x = np.array(x, dtype='float32') / 255\n",
    "    x = (x - 0.5) / 0.5 # 标准化，这个技巧之后会讲到\n",
    "    x = x.transpose((2, 0, 1)) # 将 channel 放到第一维，只是 pytorch 要求的输入方式\n",
    "    x = torch.from_numpy(x)\n",
    "    return x\n",
    "     \n",
    "train_set = CIFAR10('./data', train=True, transform=data_tf)\n",
    "train_data = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)\n",
    "test_set = CIFAR10('./data', train=False, transform=data_tf)\n",
    "test_data = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False)\n",
    "\n",
    "net = vgg()\n",
    "optimizer = torch.optim.SGD(net.parameters(), lr=1e-1)\n",
    "criterion = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2017-12-22T09:12:46.868967Z",
     "start_time": "2017-12-22T09:01:59.924086Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 0. Train Loss: 2.303118, Train Acc: 0.098186, Valid Loss: 2.302944, Valid Acc: 0.099585, Time 00:00:32\n",
      "Epoch 1. Train Loss: 2.303085, Train Acc: 0.096907, Valid Loss: 2.302762, Valid Acc: 0.100969, Time 00:00:33\n",
      "Epoch 2. Train Loss: 2.302916, Train Acc: 0.097287, Valid Loss: 2.302740, Valid Acc: 0.099585, Time 00:00:33\n",
      "Epoch 3. Train Loss: 2.302395, Train Acc: 0.102042, Valid Loss: 2.297652, Valid Acc: 0.108782, Time 00:00:32\n",
      "Epoch 4. Train Loss: 2.079523, Train Acc: 0.202026, Valid Loss: 1.868179, Valid Acc: 0.255736, Time 00:00:31\n",
      "Epoch 5. Train Loss: 1.781262, Train Acc: 0.307625, Valid Loss: 1.735122, Valid Acc: 0.323279, Time 00:00:31\n",
      "Epoch 6. Train Loss: 1.565095, Train Acc: 0.400975, Valid Loss: 1.463914, Valid Acc: 0.449565, Time 00:00:31\n",
      "Epoch 7. Train Loss: 1.360450, Train Acc: 0.495225, Valid Loss: 1.374488, Valid Acc: 0.490803, Time 00:00:31\n",
      "Epoch 8. Train Loss: 1.144470, Train Acc: 0.585758, Valid Loss: 1.384803, Valid Acc: 0.524624, Time 00:00:31\n",
      "Epoch 9. Train Loss: 0.954556, Train Acc: 0.659287, Valid Loss: 1.113850, Valid Acc: 0.609968, Time 00:00:32\n",
      "Epoch 10. Train Loss: 0.801952, Train Acc: 0.718131, Valid Loss: 1.080254, Valid Acc: 0.639933, Time 00:00:31\n",
      "Epoch 11. Train Loss: 0.665018, Train Acc: 0.765945, Valid Loss: 0.916277, Valid Acc: 0.698972, Time 00:00:31\n",
      "Epoch 12. Train Loss: 0.547411, Train Acc: 0.811241, Valid Loss: 1.030948, Valid Acc: 0.678896, Time 00:00:32\n",
      "Epoch 13. Train Loss: 0.442779, Train Acc: 0.846228, Valid Loss: 0.869791, Valid Acc: 0.732496, Time 00:00:32\n",
      "Epoch 14. Train Loss: 0.357279, Train Acc: 0.875440, Valid Loss: 1.233777, Valid Acc: 0.671677, Time 00:00:31\n",
      "Epoch 15. Train Loss: 0.285171, Train Acc: 0.900096, Valid Loss: 0.852879, Valid Acc: 0.765131, Time 00:00:32\n",
      "Epoch 16. Train Loss: 0.222431, Train Acc: 0.923374, Valid Loss: 1.848096, Valid Acc: 0.614023, Time 00:00:31\n",
      "Epoch 17. Train Loss: 0.174834, Train Acc: 0.939478, Valid Loss: 1.137286, Valid Acc: 0.728639, Time 00:00:31\n",
      "Epoch 18. Train Loss: 0.144375, Train Acc: 0.950587, Valid Loss: 0.907310, Valid Acc: 0.776800, Time 00:00:31\n",
      "Epoch 19. Train Loss: 0.115332, Train Acc: 0.960878, Valid Loss: 1.009886, Valid Acc: 0.761175, Time 00:00:31\n"
     ]
    }
   ],
   "source": [
    "train(net, train_data, test_data, 20, optimizer, criterion)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以看到，跑完 20 次，vgg 能在 cifar 10 上取得 76% 左右的测试准确率"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
